🔴🔴🔴 [Attention] Refactor Attention Interface for Bart-based Models
#38108
Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the "Ready for review" button.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
PR is ready again; I'll check whether some flags have been missed, but it's good otherwise.
🔴 for the abstraction of the attention interface!
Let's keep the core logic explicit, and just put each warning in each integration/sdpa|flash|flex.py file!
attn_mask=causal_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal,
attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS(
No, here we want one explicit thing: the default attention is eager_attention_forward!
The core philosophy when we abstract is to keep the indirections to a minimum, and thus the core logic (eager attention) should be explicit!
But very nice to put the rest (warnings etc.) inside each function, just not in ALL_ATTENTION_FUNCTIONS's call!
Thus each sdpa, flex or flash has its own warning; we should not abstract that!
attention_mask=attention_mask,
training=self.training,
dropout=self.dropout,
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
Suggested change:
-    attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+    attention_interface = eager_attn_forward
+    if self.config._attn_implementation != "eager":
+        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
Let's make it explicit what the default is!
You're too fast :D Will change in a sec. It wouldn't work either way with attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation], since eager is not registered in the interface.
haha yep exactly
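For readers following along, here is a minimal, self-contained sketch of the dispatch pattern agreed on above. The names mirror the transformers ones (eager_attention_forward, ALL_ATTENTION_FUNCTIONS), but the function bodies and the select_attention helper are illustrative stand-ins, not the library code:

from typing import Callable, Dict

import torch
import torch.nn.functional as F


def eager_attention_forward(query, key, value, attention_mask=None, dropout=0.0, **kwargs):
    # Explicit default path: plain matmul + softmax attention, always available.
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) / query.shape[-1] ** 0.5
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = torch.softmax(attn_weights, dim=-1)
    attn_weights = F.dropout(attn_weights, p=dropout)
    attn_output = torch.matmul(attn_weights, value)
    return attn_output, attn_weights


def sdpa_attention_forward(query, key, value, attention_mask=None, dropout=0.0, **kwargs):
    # Optional backend: never materializes the attention matrix, so weights are None.
    attn_output = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=dropout
    )
    return attn_output, None


# Registry of optional backends; "eager" is deliberately NOT registered here,
# which is why the plain dict lookup in the original line could not work.
ALL_ATTENTION_FUNCTIONS: Dict[str, Callable] = {
    "sdpa": sdpa_attention_forward,
}


def select_attention(attn_implementation: str) -> Callable:
    # Keep the default explicit: eager is the core logic, everything else is a lookup.
    attention_interface: Callable = eager_attention_forward
    if attn_implementation != "eager":
        attention_interface = ALL_ATTENTION_FUNCTIONS[attn_implementation]
    return attention_interface

With this shape, select_attention("eager") never touches the registry, while an unregistered implementation fails loudly at the dict lookup.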
tests/test_modeling_common.py (outdated)
config._attn_implementation = "flex_attention"
model = model_class(config).to(device=torch_device, dtype=torch.float16)
# Flex Attention can not use dropout
if hasattr(config, "attention_droput"):
Need to fix this typo (discovered during the Whisper work).
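A quick, self-contained illustration of why the misspelled attribute check is a silent no-op; DummyConfig is a hypothetical stand-in for a model config that exposes attention_dropout (as Bart-style configs do):

class DummyConfig:
    attention_dropout = 0.1  # attribute name used by Bart-style configs


config = DummyConfig()

# Misspelled check from the snippet above: always False, so dropout is never
# zeroed out before running the flex-attention test.
assert not hasattr(config, "attention_droput")

# Corrected check: fires as intended and disables dropout for flex attention.
if hasattr(config, "attention_dropout"):
    config.attention_dropout = 0.0

assert config.attention_dropout == 0.0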
Rerunning flex tests to see which models will fail.
Seems like a lot of models fail on flex attention... Not sure if I should submit a torch issue. Disabled flex on them for now - I don't think it's a high priority atm.
Kudos to you! Let's break a bit and warn without a fallback, WDYT? Not adding a new arg!
| "Falling back to eager attention because `flash_attention_2` does not support" | ||
| " `output_attentions=True` or `head_mask`." | ||
| ) | ||
| return eager_fallback( |
If eager_fallback is None we are failing hard! Let's maybe just emit the warning, no? And return None output attentions.
| "Falling back to eager attention because `flex_attention` does not support" | ||
| " `output_attentions=True`, `head_mask`, or `dropout`." | ||
| ) | ||
| return eager_fallback( |
Same here! And head_mask is something we kind of deprecated, so let's probably just return None.
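A minimal sketch of the warn-without-fallback behaviour proposed in these two threads, assuming the final code just emits the warning and returns None attention weights; the function name, logger, and the SDPA call standing in for the real flash kernel are all illustrative:

import logging
from typing import Optional, Tuple

import torch
import torch.nn.functional as F

logger = logging.getLogger(__name__)


def flash_attention_forward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    output_attentions: bool = False,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    if output_attentions or head_mask is not None:
        # No eager fallback: warn and return None attention weights, since
        # flash attention never materializes the full attention matrix.
        logger.warning(
            "`flash_attention_2` does not support `output_attentions=True` or `head_mask`; "
            "returning `None` for the attention weights."
        )
    # Stand-in for the real flash-attention kernel call.
    attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
    return attn_output, None

The flex_attention branch would follow the same pattern, so the eager_fallback argument can be dropped entirely.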
Another day, another merge conflict.

Checked offline again with Arthur, merging.
This PR is going to tackle two things in general: refactoring the attention interface for Bart-based models, and enabling flex attention.
Affected models (will be updated when I have enough time) - probably not 100% accurate
Possibly doable in this PR:
Worth a discussion (?):
Worth a discussion TL;DR:
Future PRs will address: